{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm_notebook\n",
    "import pickle\n",
    "import os\n",
    "import logging\n",
    "import time\n",
    "from IPython.core.debugger import set_trace\n",
    "import glob\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from utils.utils import save_checkpoint, load_checkpoint, set_logger\n",
    "from utils.gpu_utils import set_n_get_device\n",
    "\n",
    "from dataset.dataset import prepare_trainset\n",
    "\n",
    "#from model.model_unet_classify_zero import UNetResNet34\n",
    "from model.model_unet_classify_zero2 import UNetResNet34\n",
    "from model.deeplab_model_kaggler.lr_scheduler import LR_Scheduler\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "====MODEL ACHITECTURE: UNetResNet34====\n"
     ]
    }
   ],
   "source": [
    "######### Config the training process #########\n",
    "#device = set_n_get_device(\"0, 1, 2, 3\", data_device_id=\"cuda:0\")#0, 1, 2, 3, IMPORTANT: data_device_id is set to free gpu for storing the model, e.g.\"cuda:1\"\n",
    "MODEL = 'UNetResNet34'#'RESNET34', 'RESNET18', 'INCEPTION_V3', 'BNINCEPTION', 'SEResnet50'\n",
    "#AUX_LOGITS = True#False, only for 'INCEPTION_V3'\n",
    "print('====MODEL ACHITECTURE: %s===='%MODEL)\n",
    "\n",
    "device = set_n_get_device(\"0\", data_device_id=\"cuda:0\")#0, 1, 2, 3, IMPORTANT: data_device_id is set to free gpu for storing the model, e.g.\"cuda:1\"\n",
    "multi_gpu = None #[0, 1] #None\n",
    "\n",
    "SEED = 1234#5678#4567#3456#2345#1234\n",
    "debug = True# if True, load 100 samples\n",
    "IMG_SIZE = (256,1600)\n",
    "BATCH_SIZE = 16\n",
    "NUM_WORKERS = 24\n",
    "torch.cuda.manual_seed_all(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Count images in train/valid folder:  12568 1801\n",
      "Count of trainset (for training):  1120\n",
      "Count of validset (for training):  120\n"
     ]
    }
   ],
   "source": [
    "train_dl, val_dl = prepare_trainset(BATCH_SIZE, NUM_WORKERS, SEED, IMG_SIZE, debug)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i, (images, masks) in enumerate(train_dl):\n",
    "    images = images.to(device=device, dtype=torch.float)\n",
    "    #1 for non-zero-mask\n",
    "    truth = (torch.sum(masks.reshape(masks.size()[0], masks.size()[1], -1), dim=2, keepdim=False)!=0).to(device=device, dtype=torch.float)\n",
    "    if i==0:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 1, 256, 1600]), torch.Size([16, 4]))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "images.size(), truth.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 1., 0.],\n",
       "        [0., 0., 0., 0.],\n",
       "        [0., 0., 1., 0.]], device='cuda:0')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = UNetResNet34(debug=True).cuda(device=device)\n",
    "#net = ZeroMaskClassifier(debug=True).cuda(device=device)\n",
    "\n",
    "#torch.cuda.set_device(0)\n",
    "#torch.distributed.init_process_group(backend='nccl', world_size=4, init_method='...')\n",
    "#net = DistributedDataParallel(net, device_ids=[0], output_device=0)\n",
    "#torch.distributed.init_process_group(backend=\"nccl\")\n",
    "\n",
    "#checkpoint_path = 'checkpoint/UNetResNet34_512_v1_seed3456/best.pth.tar'\n",
    "#net, _ = load_checkpoint(checkpoint_path, net)\n",
    "\n",
    "if multi_gpu is not None:\n",
    "    net = nn.DataParallel(net, device_ids=multi_gpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input:  torch.Size([16, 3, 256, 1600])\n",
      "e1 torch.Size([16, 64, 128, 800])\n",
      "e2 torch.Size([16, 64, 128, 800])\n",
      "e3 torch.Size([16, 128, 64, 400])\n",
      "e4 torch.Size([16, 256, 32, 200])\n",
      "e5 torch.Size([16, 512, 16, 100])\n",
      "avgpool torch.Size([16, 512, 1, 1])\n",
      "reshape torch.Size([16, 512])\n",
      "logit_clf torch.Size([16, 4])\n"
     ]
    }
   ],
   "source": [
    "logit = net(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.8676, device='cuda:0', grad_fn=<BinaryCrossEntropyWithLogitsBackward>)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "if multi_gpu:\n",
    "    loss = net.module.criterion(logit, truth)\n",
    "else:\n",
    "    loss = net.criterion(logit, truth)\n",
    "\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0.26415, 28, 25, 11, 0, 0.390625)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "if multi_gpu:\n",
    "    _, metric, tn, fp, fn, tp, pos_percent = net.module.metric(logit, truth)\n",
    "else:\n",
    "    _, metric, tn, fp, fn, tp, pos_percent = net.metric(logit, truth)\n",
    "\n",
    "metric, tn, fp, fn, tp, pos_percent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "#label_df = pd.read_csv('data/raw/NIH/NIH external data/label_df_pneumothorax.csv').set_index('img_id')\n",
    "\n",
    "for idx, img_id in enumerate(label_df.index):\n",
    "    if idx>10000:\n",
    "        break\n",
    "    has_pneumothorax = label_df.loc[img_id, 'has_pneumothorax']\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for idx, f in enumerate(glob.glob('data/raw/NIH/NIH external data/images*/*')):\n",
    "    img = plt.imread(f)\n",
    "    if len(img.shape)==3:\n",
    "        print(img.shape)\n",
    "        break\n",
    "    if idx>100:\n",
    "        break\n",
    "#img_path = glob.glob('data/raw/NIH/NIH external data/images*/00000001_000.png')[0]#00000003_000.png\n",
    "#img = plt.imread(img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#plt.imshow(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## predict the validset, and analyse  \n",
    "**first, predict zero-nonzero-mask**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#move checkpoint from gamma machine to here\n",
    "cd checkpoint\n",
    "scp -r endi.niu@10.171.36.214:/home/endi.niu/SIIM/checkpoint/nonzero_classifier_UNetResNet34_768_v1_seed1234/ nonzero_classifier_UNetResNet34_768_v1_seed1234\n",
    "cd logging\n",
    "scp -r endi.niu@10.171.36.214:/home/endi.niu/SIIM/logging/nonzero_classifier_UNetResNet34_768_v1_seed1234.log nonzero_classifier_UNetResNet34_768_v1_seed1234.log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import math\n",
    "from matplotlib import pyplot as plt\n",
    "from tqdm import tqdm_notebook\n",
    "import pickle\n",
    "import os\n",
    "import logging\n",
    "import time\n",
    "from IPython.core.debugger import set_trace\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from utils.utils import save_checkpoint, load_checkpoint, set_logger\n",
    "from utils.gpu_utils import set_n_get_device\n",
    "\n",
    "from model.model_unet_classify_zero import UNetResNet34, predict_proba\n",
    "from dataset.dataset import prepare_trainset\n",
    "\n",
    "from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "######### Config the training process #########\n",
    "#device = set_n_get_device(\"0, 1, 2, 3\", data_device_id=\"cuda:0\")#0, 1, 2, 3, IMPORTANT: data_device_id is set to free gpu for storing the model, e.g.\"cuda:1\"\n",
    "MODEL = 'UNetResNet34'#'RESNET34', 'RESNET18', 'INCEPTION_V3', 'BNINCEPTION', 'SEResnet50'\n",
    "#AUX_LOGITS = True#False, only for 'INCEPTION_V3'\n",
    "print('====MODEL ACHITECTURE: %s===='%MODEL)\n",
    "\n",
    "device = set_n_get_device(\"0,1\", data_device_id=\"cuda:0\")#0, 1, 2, 3, IMPORTANT: data_device_id is set to free gpu for storing the model, e.g.\"cuda:1\"\n",
    "multi_gpu = [0, 1]#use 2 gpus\n",
    "\n",
    "SEED = 1234#5678#4567#3456#2345#1234\n",
    "debug = False# if True, load 100 samples\n",
    "IMG_SIZE = (256,1600)\n",
    "BATCH_SIZE = 32\n",
    "NUM_WORKERS = 24"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "train_dl, val_dl = prepare_trainset(BATCH_SIZE, NUM_WORKERS, SEED, IMG_SIZE, debug)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# y should be makeup\n",
    "y_valid = []\n",
    "for i, (images, masks) in enumerate(val_dl):\n",
    "    #if i==10:\n",
    "    #    break\n",
    "    #truth = masks.to(device=device, dtype=torch.float)\n",
    "    truth = (torch.sum(masks.reshape(masks.size()[0], masks.size()[1], -1), dim=2, keepdim=False)!=0)#.to(device=device, dtype=torch.float)\n",
    "    y_valid.append(truth.numpy())\n",
    "y_valid = np.concatenate(y_valid, axis=0)\n",
    "y_valid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_valid.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = UNetResNet34(debug=False).cuda(device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_path = '../checkpoint/nonzero_classifier_UNetResNet34_256x1600_v2_seed1234/best.pth.tar'\n",
    "net, _ = load_checkpoint(checkpoint_path, net)\n",
    "\n",
    "if multi_gpu is not None:\n",
    "    net = nn.DataParallel(net, device_ids=multi_gpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "preds_valid = predict_proba(net, val_dl, device, multi_gpu=multi_gpu, mode='valid', tta=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_valid.shape, preds_valid.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import roc_auc_score, confusion_matrix\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x):\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "def cal_metric(logit, truth):\n",
    "    #pred = sigmoid(logit.cpu().detach())\n",
    "    pred = sigmoid(logit)\n",
    "    #truth = truth.cpu().detach().numpy()\n",
    "    ##\n",
    "    THRESHOLD_candidate = [0.5]#np.arange(0.01, 0.99, 0.01)\n",
    "    N = len(THRESHOLD_candidate)\n",
    "    best_threshold = [0.5, 0.5, 0.5, 0.5]\n",
    "    best_score = -1\n",
    "    tn, fp, fn, tp, pos_percent = 0, 0, 0, 0, 0.0\n",
    "\n",
    "    for ch in range(4):\n",
    "        for i in range(N):\n",
    "            THRESHOLD = copy.deepcopy(best_threshold)\n",
    "            THRESHOLD[ch] = THRESHOLD_candidate[i]\n",
    "            _pred = pred>THRESHOLD\n",
    "            _pred, truth = _pred.reshape(-1, 1), truth.reshape(-1, 1)\n",
    "            \n",
    "            _tn, _fp, _fn, _tp = confusion_matrix(truth, _pred).ravel()\n",
    "            _auc = round(roc_auc_score(truth, _pred), 5)\n",
    "            if _tn+_fn==0:\n",
    "                fn_rate = 9999\n",
    "            else:\n",
    "                fn_rate = round(_fn/(_tn+_fn), 5)\n",
    "            _pos_percent = (_tp+_fp)/(_tp+_fp+_tn+_fn)\n",
    "            \n",
    "            if _auc > best_score:\n",
    "                best_threshold = copy.deepcopy(THRESHOLD)\n",
    "                best_score = _auc\n",
    "                tn, fp, fn, tp, pos_percent = _tn, _fp, _fn, _tp, _pos_percent\n",
    "    return best_threshold, best_score, tn, fp, fn, tp, pos_percent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_threshold, best_score, tn, fp, fn, tp, pos_percent = cal_metric(preds_valid, y_valid)\n",
    "best_threshold, best_score, tn, fp, fn, tp, pos_percent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "best_threshold, best_score, tn, fp, fn, tp, pos_percent = cal_metric(preds_valid, y_valid)\n",
    "best_threshold, best_score, tn, fp, fn, tp, pos_percent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# THRESHOLD_candidate = np.arange(0.01, 0.99, 0.01)#np.arange(0.1, 1.0, 0.1)\n",
    "\n",
    "# N = len(THRESHOLD_candidate)\n",
    "# best_threshold = None\n",
    "# best_score = 9999\n",
    "\n",
    "# for i in tqdm_notebook(range(N)):\n",
    "#     THRESHOLD = THRESHOLD_candidate[i]\n",
    "#     auc, fn_rate, tn, fp, fn, tp = cal_metric(preds_valid, y_valid, THRESHOLD)\n",
    "#     #predict positive proportion should be ~28%\n",
    "#     pos_percent = (tp+fp)/(tp+fp+tn+fn)\n",
    "#     if 0.2<pos_percent<0.28:\n",
    "#         flag = 'Yes'\n",
    "#         if fn_rate < best_score:\n",
    "#             best_threshold = THRESHOLD\n",
    "#             best_score = fn_rate\n",
    "#     else:\n",
    "#         flag = 'No'\n",
    "#     print('[%s]THRESHOLD: %.2f, metric_score: %.5f, pos_percent: %.3f; tn, fp, fn, tp, auc: %d, %d, %d, %d, %.5f'%(flag, THRESHOLD, fn_rate, pos_percent, tn, fp, fn, tp, auc))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# THRESHOLD = best_threshold\n",
    "# THRESHOLD, best_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(classification_report(y_valid, sigmoid(preds_valid)>best_threshold))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sigmoid(x):\n",
    "    return 1 / (1 + np.exp(-x))\n",
    "\n",
    "def predict_mask(logit, best_threshold):\n",
    "    \"\"\"Transform each prediction into mask.\n",
    "    input shape: (256, 256)\n",
    "    \"\"\"\n",
    "    #pred mask 0-1 pixel-wise\n",
    "    #n = logit.shape[0]\n",
    "    #IMG_SIZE = logit.shape[-1] #256\n",
    "    #EMPTY_THRESHOLD = 100.0*(IMG_SIZE/128.0)**2 #count of predicted mask pixles<threshold, predict as empty mask image\n",
    "    #MASK_THRESHOLD = 0.22\n",
    "    #logit = torch.sigmoid(torch.from_numpy(logit)).view(n, -1)\n",
    "    #pred = (logit>MASK_THRESHOLD).long()\n",
    "    #pred[pred.sum(dim=1) < EMPTY_THRESHOLD, ] = 0 #bug here, found it, the bug is input shape is (256, 256) not (16,256,256)\n",
    "    logit = sigmoid(logit)#.reshape(n, -1)\n",
    "    pred = (logit>best_threshold).astype(np.int)\n",
    "    return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "## visualize predicted masks\n",
    "start = 1\n",
    "total = 19\n",
    "\n",
    "fig=plt.figure(figsize=(15, 20))\n",
    "cnt = 0\n",
    "for idx, (images, masks) in enumerate(val_dl):\n",
    "    if idx<start:\n",
    "        continue\n",
    "    labels = (torch.sum(masks.reshape(masks.size()[0], -1), dim=1, keepdim=True)!=0).int()\n",
    "    for j in range(BATCH_SIZE):#BATCH_SIZE=8\n",
    "        cnt+=1\n",
    "        pred_mask = predict_mask(preds_valid[idx*BATCH_SIZE+j], best_threshold)\n",
    "        #if pred_mask.float().mean()==0:\n",
    "        #    continue\n",
    "        ax = fig.add_subplot(5, 4, cnt)\n",
    "        plt.imshow(images[j][0].numpy(), plt.cm.bone)\n",
    "        if labels[j]==1:\n",
    "            title = 'GT=1'\n",
    "        else:\n",
    "            title = 'GT=0'\n",
    "        if pred_mask==1:\n",
    "            title += ';PRED=1'\n",
    "        else:\n",
    "            title += ';PRED=0'\n",
    "        plt.title(title)\n",
    "        if cnt>total:\n",
    "            break\n",
    "    if cnt>total:\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#100, 256, 334, 378, 547, 667, 916\n",
    "#47, 173\n",
    "plt.imshow(_logit[173], cmap='Reds')\n",
    "#plt.imshow(_logit[256]>0.001, cmap='Reds')\n",
    "#(_logit[100]>0.01).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#_logit = sigmoid(preds_valid_masks)#.reshape(n, -1)\n",
    "_pred = (_logit>MASK_THRESHOLD).astype(np.int)\n",
    "_pred_clf = (_pred.reshape(_pred.shape[0], -1).sum(axis=1)<EMPTY_THRESHOLD).astype(np.int)\n",
    "#pred[pred_clf.reshape(-1,)==1, ] = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#unet mask model\n",
    "print(classification_report(y_valid[:, 0], 1-_pred_clf[:y_valid.shape[0]]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(1-_pred_clf[:y_valid.shape[0]]).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#classify zero-nonzero model\n",
    "print(classification_report(y_valid, sigmoid(preds_valid)>THRESHOLD))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(sigmoid(preds_valid)>THRESHOLD).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size = [8, 4, 64, 64]\n",
    "n = size[0]//2\n",
    "\n",
    "masks = torch.cat([torch.randint(0, 2, [n]+size[1:]).float(), \n",
    "                   torch.randint(0, 1, [n]+size[1:]).float()], dim=0)\n",
    "truth = (torch.sum(masks.reshape(masks.size()[0], masks.size()[1], -1), dim=2, keepdim=False)!=0).float()\n",
    "logit = torch.randn((size[0], size[1]))\n",
    "\n",
    "logit.shape, truth.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Loss_FUNC = torch.nn.BCEWithLogitsLoss()\n",
    "Loss_FUNC(logit, truth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import copy\n",
    "from sklearn.metrics import confusion_matrix, roc_auc_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def metric(logit, truth):\n",
    "    \"\"\"\n",
    "    AUC score as metric\n",
    "    \"\"\"\n",
    "    pred = sigmoid(logit.cpu().detach().numpy())\n",
    "    truth = truth.cpu().detach().numpy()\n",
    "    ##\n",
    "    THRESHOLD_candidate = np.arange(0.01, 0.99, 0.01)\n",
    "    N = len(THRESHOLD_candidate)\n",
    "    best_threshold = [0.01, 0.01, 0.01, 0.01]\n",
    "    best_score = -1\n",
    "    tn, fp, fn, tp, pos_percent = 0, 0, 0, 0, 0.0\n",
    "\n",
    "    for ch in range(4):\n",
    "        for i in range(N):\n",
    "            THRESHOLD = copy.deepcopy(best_threshold)\n",
    "            THRESHOLD[ch] = THRESHOLD_candidate[i]\n",
    "            _pred = pred>THRESHOLD\n",
    "            _pred, truth = _pred.reshape(-1, 1), truth.reshape(-1, 1)\n",
    "            \n",
    "            _tn, _fp, _fn, _tp = confusion_matrix(truth, _pred).ravel()\n",
    "            _auc = round(roc_auc_score(truth, _pred), 5)\n",
    "            if _tn+_fn==0:\n",
    "                fn_rate = 9999\n",
    "            else:\n",
    "                fn_rate = round(_fn/(_tn+_fn), 5)\n",
    "            _pos_percent = (_tp+_fp)/(_tp+_fp+_tn+_fn)\n",
    "            \n",
    "            if _auc > best_score:\n",
    "                best_threshold = copy.deepcopy(THRESHOLD)\n",
    "                best_score = _auc\n",
    "                tn, fp, fn, tp, pos_percent = _tn, _fp, _fn, _tp, _pos_percent\n",
    "    return np.round(best_threshold, 2), best_score, tn, fp, fn, tp, pos_percent\n",
    "\n",
    "def sigmoid(x):\n",
    "    return 1 / (1 + np.exp(-x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "metric(logit, truth)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
